#!/usr/bin/env python3
#------------------------------------------------------------------------------#
#  DFTB+: general package for performing fast atomistic simulations            #
#  Copyright (C) 2006 - 2020  DFTB+ developers group                           #
#                                                                              #
#  See the LICENSE file for terms of usage and distribution.                   #
#------------------------------------------------------------------------------#
#
"""
Calculates forces and lattice derivatives with finite differences.
"""
import os
import subprocess
import argparse
import re
import numpy as np
import numpy.linalg as la

# Length conversion factors between Bohr and Angstrom: multiplying a length in
# Bohr by BOHR__AA gives Angstrom, multiplying by AA__BOHR goes the other way.
BOHR__AA = 0.529177249
AA__BOHR = 1.0 / BOHR__AA

# Description text handed over to the command line argument parser.
DESCRIPTION = """
Calculates the forces using the specified DFTB+ binary by finite differences
displacing the atoms along every axis. The geometry of the configuration must
be specified in a file called 'geo.gen.template'. The DFTB+ input file should
include the geometry from the file 'geo.gen'. The input file must specify the
option for writing an autotest file and if the reference should be calculated
then also the option for calculating the forces.
"""


# Matches a 'mermin_energy' entry followed by three colon terminated fields
# and captures the next whitespace free token as group 'value'.
ENERGY_PATTERN = re.compile(r"mermin_energy[^:]*:[^:]*:[^:]*:\s*(?P<value>\S+)")

# Matches a 'forces' entry whose header ends in a '<int>,<int>' field and
# captures the following whitespace separated numbers as group 'values'.
# NOTE(review): in contrast to LATTICE_DERIV_PATTERN this pattern is not
# anchored with '^', so re.MULTILINE has no effect on it and it could also
# match within a longer tag containing 'forces' -- confirm this is intended.
FORCES_PATTERN = re.compile(
    r"forces[^:]*:[^:]*:[^:]*:\d+,\d+\s*"
    r"(?P<values>(?:\s*[+-]?\d+(?:\.\d+(?:E[+-]?\d+)?)?)+)",
    re.MULTILINE)

# Matches a 'stress' entry at the beginning of a line and captures the
# following whitespace separated numbers as group 'values'. The stress is
# converted to lattice derivatives when the reference is read.
LATTICE_DERIV_PATTERN = re.compile(
    r"^stress[^:]*:[^:]*:[^:]*:\d+,\d+\s*"
    r"(?P<values>(?:\s*[+-]?\d+(?:\.\d+(?:E[+-]?\d+)?)?)+)", re.MULTILINE)

# Name of the file in which the autotest output of a freshly calculated
# reference system is stored.
REFERENCE_AUTOTEST = 'autotest0.tag'


def main():
    """Drives the finite difference calculation.

    Reads the geometry from 'geo.gen.template', optionally calculates or loads
    reference derivatives, evaluates the requested forces and lattice
    derivatives by finite differences and prints the results (together with
    the deviations from the reference, if one is available).
    """
    args = parse_arguments()
    geometry = readgen("geo.gen.template")
    specienames, species, coords, origin, latvecs = geometry

    do_forces = not args.skipforces
    # Lattice derivatives only make sense if lattice vectors are present
    do_latderivs = (latvecs is not None) and not args.skiplattice

    # A freshly calculated reference is stored under a fixed name, otherwise
    # the file passed by the user (None, if not specified) is taken.
    reffile = REFERENCE_AUTOTEST if args.calcref else args.ref

    # The displacement is specified in Angstrom on the command line
    disp = args.disp * AA__BOHR
    binary = args.binary
    print("BINARY:", binary)

    if args.calcref:
        calculate_reference(binary, reffile, coords, specienames, species,
                            origin, latvecs)

    forces0 = latderivs0 = None
    if reffile is not None:
        forces0, latderivs0 = read_reference_results(
            reffile, do_forces, do_latderivs, latvecs)

    # All calculations are carried out first, the summaries are printed last
    if do_forces:
        forces = calculate_forces(
            binary, disp, coords, specienames, species, origin, latvecs)
    if do_latderivs:
        latderivs = calculate_latderivs(
            binary, disp, coords, latvecs, specienames, species, origin)

    if do_forces:
        print_forces(forces, forces0)
    if do_latderivs:
        print_latderivs(latderivs, latderivs0)


def parse_arguments():
    """Sets up the command line parser and returns the parsed arguments.

    The parser terminates the program with an error message, if a reference
    file and the calculation of the reference system are requested together.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)

    parser.add_argument(
        "-d", "--displacement", type=float, dest="disp", default=1e-5,
        help="Specify the displacement of the atoms (unit: ANGSTROM)")

    parser.add_argument(
        "-r", "--reference", dest="ref",
        help="Compare derivatives with those in reference autotest file")

    parser.add_argument(
        "-c", "--calc-reference", dest="calcref", action="store_true",
        default=False,
        help=("Calculate reference system (and compare derivatives with it),"
              " resulting autotest file will be saved as 'autotest0.tag'"))

    parser.add_argument(
        "-L", "--skip-lattice", dest="skiplattice", action="store_true",
        default=False,
        help=("Skip the calculation of the lattice derivatives in the case of"
              " periodic systems"))

    parser.add_argument(
        "-F", "--skip-forces", dest="skipforces", action="store_true",
        default=False,
        help=("Skip the calculation of forces (useful when only lattice"
              " derivatives should be calculated)"))

    parser.add_argument("binary", help="DFTB+ binary")

    args = parser.parse_args()
    # An explicit reference file and a freshly calculated one exclude each
    # other, as only one reference can be compared against.
    if args.calcref and args.ref is not None:
        parser.error(
            "Specifying a reference file and requesting a calculation of the"
            " reference system are mutually exclusive options")
    return args


def read_reference_results(autotest0, calcforces, calclatderivs, latvecs):
    """Reads reference forces and lattice derivatives from an autotest file.

    Args:
        autotest0: Name of the autotest file with the reference results.
        calcforces: Whether the forces should be extracted from the file.
        calclatderivs: Whether the lattice derivatives should be extracted.
        latvecs: Lattice vectors, handed over to stress2latderivs() in order
            to convert the stored stress (only used if calclatderivs is set).

    Returns:
        Tuple (forces0, latderivs0). A quantity which was not requested is
        returned as None.

    Raises:
        ValueError: If a requested quantity can not be found in the file.
    """
    with open(autotest0, "r") as fp:
        txt = fp.read()

    forces0 = None
    if calcforces:
        match = FORCES_PATTERN.search(txt)
        if match is None:
            # Raising a bare string is not valid in Python 3
            raise ValueError("No forces found in reference file!")
        values = [float(word) for word in match.group("values").split()]
        forces0 = np.array(values, dtype=float).reshape((-1, 3))

    latderivs0 = None
    if calclatderivs:
        match = LATTICE_DERIV_PATTERN.search(txt)
        if match is None:
            raise ValueError("No lattice derivatives found in reference file!")
        values = [float(word) for word in match.group("values").split()]
        stress0 = np.array(values, dtype=float).reshape((-1, 3))
        # The autotest file stores the stress, not the lattice derivatives
        latderivs0 = stress2latderivs(stress0, latvecs)

    return forces0, latderivs0


def calculate_reference(binary, reffile, coords, specienames, species, origin,
                        latvecs):
    """Calculates the undisplaced reference system.

    Writes the geometry to 'geo.gen', runs the binary and renames the
    resulting 'autotest.tag' to the given reference file name.

    Args:
        binary: Binary to execute (invoked without any arguments).
        reffile: Name under which the resulting autotest file is stored.
        coords: Coordinates of the atoms.
        specienames: Names of the species.
        species: Species index of each atom.
        origin: Origin of the lattice (None for non-periodic systems).
        latvecs: Lattice vectors (None for non-periodic systems).

    Raises:
        subprocess.CalledProcessError: If the binary exits with a non-zero
            status.
    """
    writegen("geo.gen", (specienames, species, coords, origin, latvecs))
    # Stop on a failed run, otherwise an 'autotest.tag' left over from an
    # earlier calculation could be mistaken for the reference.
    subprocess.check_call([binary])
    os.rename("autotest.tag", reffile)


def calculate_forces(binary, disp, coords, specienames, species, origin,
                     latvecs):
    """Calculates forces by central finite differences of the energy.

    Every Cartesian component of every atom is displaced by -disp and +disp.
    For each displaced geometry 'geo.gen' is written, the binary is executed
    and the energy is read from the resulting 'autotest.tag'.

    Args:
        binary: Binary to execute (invoked without any arguments).
        disp: Displacement, in the same length unit as coords.
        coords: Coordinates of the atoms.
        specienames: Names of the species.
        species: Species index of each atom.
        origin: Origin of the lattice (None for non-periodic systems).
        latvecs: Lattice vectors (None for non-periodic systems).

    Returns:
        Array of shape (len(coords), 3) with the negative energy derivatives.

    Raises:
        subprocess.CalledProcessError: If the binary exits with a non-zero
            status.
        ValueError: If no energy can be found in 'autotest.tag'.
    """
    natom = len(coords)
    energy = np.empty((2,), dtype=float)
    forces = np.empty((natom, 3), dtype=float)
    for iat in range(natom):
        for ii in range(3):
            # jj = 0: negative displacement, jj = 1: positive displacement
            for jj in range(2):
                newcoords = np.array(coords, dtype=float)
                newcoords[iat][ii] += float(2 * jj - 1) * disp
                writegen("geo.gen", (specienames, species, newcoords, origin,
                                     latvecs))
                # Stop on a failed run, otherwise the energy of a previous
                # calculation would silently be read again.
                subprocess.check_call([binary])
                with open("autotest.tag", "r") as fp:
                    txt = fp.read()
                match = ENERGY_PATTERN.search(txt)
                print("iat: %2d, ii: %2d, jj: %2d" % (iat, ii, jj))
                if match is None:
                    # Raising a bare string is not valid in Python 3
                    raise ValueError("No match found!")
                energy[jj] = float(match.group("value"))
                print("energy:", energy[jj])
            # Force is the negative derivative: (E(-disp) - E(+disp)) / 2disp
            forces[iat][ii] = (energy[0] - energy[1]) / (2.0 * disp)
    return forces


def calculate_latderivs(binary, disp, coords, latvecs, specienames, species,
                        origin):
    """Calculates lattice derivatives by central finite differences.

    Every component of every lattice vector is displaced by -disp and +disp
    while the fractional coordinates of the atoms are kept fixed. For each
    distorted cell 'geo.gen' is written, the binary is executed and the
    energy is read from the resulting 'autotest.tag'.

    Args:
        binary: Binary to execute (invoked without any arguments).
        disp: Displacement, in the same length unit as latvecs.
        coords: Cartesian coordinates of the atoms.
        latvecs: Lattice vectors, one vector per row.
        specienames: Names of the species.
        species: Species index of each atom.
        origin: Origin of the lattice.

    Returns:
        Array of shape (3, 3); element [jj][ii] contains the derivative of
        the energy with respect to component ii of lattice vector jj.

    Raises:
        subprocess.CalledProcessError: If the binary exits with a non-zero
            status.
        ValueError: If no energy can be found in 'autotest.tag'.
    """
    energy = np.empty((2,), dtype=float)
    latderivs = np.empty((3, 3), dtype=float)
    # The fractional coordinates do not depend on the displacement
    fraccoords = cart2frac(latvecs, np.array(coords))
    for ii in range(3):
        for jj in range(3):
            # kk = 0: negative displacement, kk = 1: positive displacement
            for kk in range(2):
                newvecs = np.array(latvecs, dtype=float)
                newvecs[jj][ii] += float(2 * kk - 1) * disp
                newcoords = frac2cart(newvecs, fraccoords)
                writegen("geo.gen", (specienames, species, newcoords, origin,
                                     newvecs))
                # Stop on a failed run, otherwise the energy of a previous
                # calculation would silently be read again.
                subprocess.check_call([binary])
                with open("autotest.tag", "r") as fp:
                    txt = fp.read()
                match = ENERGY_PATTERN.search(txt)
                print("ii: %2d, jj: %2d, kk: %2d" % (ii, jj, kk))
                if match is None:
                    # Raising a bare string is not valid in Python 3
                    raise ValueError("No match found!")
                energy[kk] = float(match.group("value"))
                # Report the energy only after it has been read in
                print("energy:", energy[kk])
            latderivs[jj][ii] = (energy[1] - energy[0]) / (2.0 * disp)
    return latderivs


def print_forces(forces, forces0):
    """Prints the calculated forces.

    If reference forces are passed (forces0 is not None), they are printed as
    well, followed by the component-wise differences and the largest absolute
    difference.
    """
    rowformat = "%25.12E" * 3

    print("Forces by finite differences:")
    for atforce in forces:
        print(rowformat % tuple(atforce))

    # Nothing to compare with
    if forces0 is None:
        return

    print("Reference forces:")
    for atforce0 in forces0:
        print(rowformat % tuple(atforce0))

    print("Difference between obtained and reference forces:")
    diff = forces - forces0
    for atdiff in diff[:len(forces)]:
        print(rowformat % tuple(atdiff))

    print("Max diff in any force component:")
    maxdiff = abs(diff).max()
    print("%25.12E" % (maxdiff,))


def print_latderivs(latderivs, latderivs0):
    """Prints the calculated lattice derivatives.

    If reference lattice derivatives are passed (latderivs0 is not None), they
    are printed as well, followed by the component-wise differences and the
    largest absolute difference.
    """
    rowformat = "%25.12E" * 3

    def print_rows(matrix):
        """Prints the first three rows of a matrix."""
        for irow in range(3):
            print(rowformat % tuple(matrix[irow]))

    print("Lattice derivatives by finite differences:")
    print_rows(latderivs)

    # Nothing to compare with
    if latderivs0 is None:
        return

    print("Reference lattice derivatives:")
    print_rows(latderivs0)

    print("Difference between obtained and reference lattice derivatives:")
    diff = latderivs - latderivs0
    print_rows(diff)

    print("Max diff in any lattice derivatives component:")
    maxdiff = abs(diff).max()
    print("%25.12E" % (maxdiff, ))


def readgen(fname):
    """Reads in the content of a gen file.

    Args:
        fname: Name of the gen file.

    Returns:
        Tuple (specienames, species, coords, origin, latvecs) with the list
        of species names, the zero-based species index of each atom, the
        Cartesian coordinates of the atoms and the origin and the lattice
        vectors (one vector per row). Origin and lattice vectors are None, if
        the geometry type is neither 'S' nor 'F'. All lengths are scaled with
        AA__BOHR before they are returned.
    """
    # The context manager ensures the file is closed, even if parsing fails
    with open(fname, "r") as fp:
        words = fp.readline().split()
        natom = int(words[0])
        geotype = words[1].lower()
        fractional = (geotype == "f")
        periodic = fractional or (geotype == "s")

        specienames = fp.readline().split()

        coords = np.empty((natom, 3), dtype=float)
        species = np.empty(natom, dtype=int)
        for iat in range(natom):
            words = fp.readline().split()
            # First column is the atom index, species indices start with one
            species[iat] = int(words[1]) - 1
            coords[iat] = (float(words[2]), float(words[3]), float(words[4]))

        if periodic:
            origin = np.array([float(word) for word in fp.readline().split()],
                              dtype=float)
            latvecs = np.empty((3, 3), dtype=float)
            for ii in range(3):
                latvecs[ii] = [float(word) for word in fp.readline().split()]
        else:
            origin = None
            latvecs = None

    if periodic:
        if fractional:
            # Lattice vectors still have the unit of the file at this point
            coords = frac2cart(latvecs, coords)
        origin *= AA__BOHR
        latvecs *= AA__BOHR
    coords *= AA__BOHR
    return specienames, species, coords, origin, latvecs


def writegen(fname, data):
    """Writes the geometry as gen file.

    Args:
        fname: Name of the file to write.
        data: Tuple (specienames, species, coords, origin, latvecs) as
            returned by readgen(). If latvecs is None, a cluster geometry
            ('C') is written, otherwise a supercell geometry ('S') with
            origin and lattice vectors. All lengths are scaled with BOHR__AA
            before they are written.
    """
    specienames, species, coords, origin, latvecs = data
    geotype = "C" if latvecs is None else "S"
    coords = coords * BOHR__AA
    # The context manager ensures the file is closed, even if a write fails
    with open(fname, "w") as fp:
        fp.write("%5d %s\n" % (len(coords), geotype))
        fp.write("".join("%2s " % name for name in specienames) + "\n")
        for iat, atcoords in enumerate(coords):
            # Atom and species indices start with one in the gen format
            fp.write("%5d %5d %23.15E %23.15E %23.15E\n"
                     % (iat + 1, species[iat] + 1, atcoords[0], atcoords[1],
                        atcoords[2]))
        if latvecs is not None:
            fp.write("%23.15E %23.15E %23.15E\n" % tuple(origin * BOHR__AA))
            for latvec in latvecs * BOHR__AA:
                fp.write("%23.15E %23.15E %23.15E\n" % tuple(latvec))


def cart2frac(latvecs, coords):
    """Converts cartesian coordinates to fractional coordinates.

    Args:
        latvecs: Lattice vectors as (3, 3) array, one vector per row.
        coords: Cartesian coordinates as (natom, 3) array.

    Returns:
        New float array of shape (natom, 3) with the fractional coordinates.
        The input arrays are not modified.
    """
    # Always work with floats: integer input would otherwise determine the
    # type of the result and truncate the fractional coordinates.
    cartcoords = np.asarray(coords, dtype=float)
    invlatvecs = la.inv(np.asarray(latvecs, dtype=float))
    # Row-wise form of frac_i = inv(transpose(latvecs)) . cart_i
    return np.dot(cartcoords, invlatvecs)


def frac2cart(latvecs, coords):
    """Converts fractional coordinates to cartesian ones.

    Args:
        latvecs: Lattice vectors as (3, 3) array, one vector per row.
        coords: Fractional coordinates as (natom, 3) array.

    Returns:
        New float array of shape (natom, 3) with the cartesian coordinates.
        The input arrays are not modified.
    """
    # Always work with floats: integer input would otherwise determine the
    # type of the result and truncate the cartesian coordinates.
    fraccoords = np.asarray(coords, dtype=float)
    # Row-wise form of cart_i = transpose(latvecs) . frac_i
    return np.dot(fraccoords, np.asarray(latvecs, dtype=float))


def stress2latderivs(stress, latvecs):
    """Converts a stress tensor to lattice derivatives.

    The result is the transpose of the product of the stress with the inverse
    of the lattice vector matrix, scaled by the negative determinant of the
    lattice vector matrix.
    """
    cellvolume = la.det(latvecs)
    scaled = np.dot(stress, la.inv(latvecs))
    return -cellvolume * scaled.T


# Start the calculation only if the file is executed as a script, so that it
# can be imported without side effects.
if __name__ == "__main__":
    main()
